{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3ea31db8-a39e-4667-90e9-b7e5edbfa51a",
   "metadata": {},
   "source": [
    "# Build Multilingual Financial Search Applications with Cohere - Code Walkthrough\n",
    "This walkthrough shows how Cohere’s Embed model can search and\n",
    "query financial news in several languages within a single pipeline. We then show how\n",
    "adding Rerank to embeddings-based retrieval (or to a legacy lexical search) can further\n",
    "improve the results."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "734955e4-9fd1-4fb2-9839-c76b22dad070",
   "metadata": {},
   "source": [
    "### Step 0: Enable Model Access Through Amazon Bedrock"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5141dcea",
   "metadata": {},
   "source": [
    "Enable model access in the [Amazon Bedrock console](https://console.aws.amazon.com/bedrock) by following the instructions in the [Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html).\n",
    "\n",
    "For this walkthrough you will need to request access to the Cohere Embed Multilingual model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "494c504f-0886-4b29-8716-0a47e6450647",
   "metadata": {},
   "source": [
    "### Step 1: Install Packages and Import Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb2388da",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade cohere-aws hnswlib\n",
    "# If you upgrade the package, you need to restart the kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be54602a-944d-43c4-a611-6f8f57285618",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import cohere_aws\n",
    "import hnswlib\n",
    "import warnings\n",
    "import os\n",
    "import re\n",
    "import boto3\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab9e4080-542b-4aad-94f4-6cf33789c792",
   "metadata": {},
   "source": [
    "### Step 2: Import Documents\n",
    "\n",
    "Information about the MultiFIN paper and data can be found in its GitHub repo: https://github.com/RasmusKaer/MultiFin.\n",
    "\n",
    "We will use a CSV file that contains the data plus Google translations of the articles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94c04e40-e8ca-4273-b4df-d5b8e01990d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the data\n",
    "url = \"https://raw.githubusercontent.com/cohere-ai/cohere-aws/main/notebooks/bedrock/multiFIN_train.csv\"\n",
    "df = pd.read_csv(url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7197a0b7-f01e-46bd-ad2d-49395d119cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect dataset\n",
    "df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "327177dd-bc70-4259-b3d4-4b7a70eedd1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check language distribution\n",
    "df['lang'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0facfeb0-92e9-4709-b69f-1f2f60db7c43",
   "metadata": {},
   "source": [
    "### Step 3: Select List of Documents to Query\n",
    "\n",
    "We first do some quick cleaning, then select the articles to query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c96470b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We want to select the longest articles, but some are long just due to repeated text - we will clean that up\n",
    "df['text'].iloc[2215]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcbafe7d-db5b-4f3d-8234-0ec577761cdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure there is no duplicated text in the headers\n",
    "def remove_duplicates(text):\n",
    "    return re.sub(r'((\\b\\w+\\b.{1,2}\\w+\\b)+).+\\1', r'\\1', text, flags=re.I)\n",
    "\n",
    "df['text'] = df['text'].apply(remove_duplicates)\n",
    "\n",
    "# Keep only selected languages\n",
    "languages = ['English', 'Spanish', 'Danish']\n",
    "df = df.loc[df['lang'].isin(languages)].copy()  # copy so the new column below is not set on a slice\n",
    "\n",
    "# Pick the 80 longest articles\n",
    "df['text_length'] = df['text'].str.len()\n",
    "df.sort_values(by=['text_length'], ascending=False, inplace=True)\n",
    "top_80_df = df[:80]\n",
    "\n",
    "# Language distribution\n",
    "top_80_df['lang'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41bcc6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# As an example, here is one of our longest articles (index 1 is the second longest after sorting)\n",
    "top_80_df['text'].iloc[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc6f9061-96a4-4d22-af35-8c2f9d9ae19a",
   "metadata": {},
   "source": [
    "### Step 4: Embed and Index Documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41fa6476-9cff-4af0-9d03-50213d72a2e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Establish Cohere client\n",
    "co = cohere_aws.Client(mode=cohere_aws.Mode.BEDROCK)\n",
    "model_id = \"cohere.embed-multilingual-v3\"\n",
    "\n",
    "# Embed documents\n",
    "docs = top_80_df['text'].to_list()\n",
    "docs_lang = top_80_df['lang'].to_list()\n",
    "translated_docs = top_80_df['translation'].to_list() #for reference when returning non-English results\n",
    "doc_embs = co.embed(texts=docs, model_id=model_id, input_type='search_document').embeddings\n",
    "\n",
    "# Create a search index with hnswlib, a library for fast approximate nearest neighbor search\n",
    "index = hnswlib.Index(space='ip', dim=1024) # cohere.embed-multilingual-v3 outputs embeddings with 1024 dimensions\n",
    "index.init_index(max_elements=len(doc_embs), ef_construction=512, M=64) # For more info: https://github.com/nmslib/hnswlib#api-description\n",
    "index.add_items(doc_embs, list(range(len(doc_embs))))"
   ]
  },
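  {
   "cell_type": "markdown",
   "id": "toy-ip-search-note",
   "metadata": {},
   "source": [
    "The index above is built with `space='ip'`, so documents are ranked by inner product with the query embedding. As a minimal sketch of what the approximate index computes, the next cell performs the same ranking exactly with NumPy on a few made-up 4-dimensional vectors. The `toy_` names and values are illustrative assumptions; they are not part of the dataset or of the Cohere API.\n",
    "\n",
    "Note that hnswlib reports `ip` distances as 1 minus the inner product, so a smaller distance corresponds to a higher score in this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-ip-search-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy stand-ins for document embeddings (hypothetical values, 4 dimensions instead of 1024)\n",
    "toy_doc_embs = np.array([\n",
    "    [0.9, 0.1, 0.0, 0.0],\n",
    "    [0.0, 1.0, 0.0, 0.0],\n",
    "    [0.7, 0.7, 0.0, 0.0],\n",
    "])\n",
    "toy_query_emb = np.array([1.0, 0.0, 0.0, 0.0])\n",
    "\n",
    "# Exact inner-product search: score every document, then sort from highest to lowest score\n",
    "toy_scores = toy_doc_embs @ toy_query_emb\n",
    "toy_ranking = np.argsort(-toy_scores)\n",
    "\n",
    "for toy_doc_id in toy_ranking:\n",
    "    print(f'doc {toy_doc_id}: score {toy_scores[toy_doc_id]:.2f}')"
   ]
  },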
  {
   "cell_type": "markdown",
   "id": "4de25ec1-98f6-4a07-819a-ac40c3324cd3",
   "metadata": {},
   "source": [
    "### Step 5: Build a Retrieval System"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29843367-b738-4eb5-83eb-0bb5827b3a79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieval of 3 closest docs to query\n",
    "def retrieval(query):\n",
    "    # Embed query and retrieve results\n",
    "    query_emb = co.embed(texts=[query], model_id=model_id, input_type=\"search_query\").embeddings\n",
    "    doc_ids = index.knn_query(query_emb, k=3)[0][0] # we will retrieve 3 closest neighbors\n",
    "    \n",
    "    # Print and append results\n",
    "    print(f\"QUERY: {query.upper()} \\n\")\n",
    "    retrieved_docs, translated_retrieved_docs = [], []\n",
    "    \n",
    "    for doc_id in doc_ids:\n",
    "        # Append results\n",
    "        retrieved_docs.append(docs[doc_id])\n",
    "        translated_retrieved_docs.append(translated_docs[doc_id])\n",
    "    \n",
    "        # Print results\n",
    "        print(f\"ORIGINAL ({docs_lang[doc_id]}): {docs[doc_id]}\")\n",
    "        if docs_lang[doc_id] != \"English\":\n",
    "            print(f\"TRANSLATION: {translated_docs[doc_id]} \\n----\")\n",
    "        else:\n",
    "            print(\"----\")\n",
    "    print(\"END OF RESULTS \\n\\n\")\n",
    "    return retrieved_docs, translated_retrieved_docs"
   ]
  },
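  {
   "cell_type": "markdown",
   "id": "toy-knn-shape-note",
   "metadata": {},
   "source": [
    "The `[0][0]` indexing in `retrieval` follows from what `knn_query` returns: a pair of arrays (labels, distances), each with one row per query and `k` columns. The next cell is a small sketch on a throwaway index to make those shapes visible. The `toy_` names, the 4-dimensional vectors, and the index parameters are assumptions chosen for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-knn-shape-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import hnswlib\n",
    "\n",
    "# Four hypothetical 4-dimensional vectors (rows of the identity matrix)\n",
    "toy_vectors = np.eye(4, dtype=np.float32)\n",
    "\n",
    "# Throwaway index, separate from the `index` built in Step 4\n",
    "toy_index = hnswlib.Index(space='ip', dim=4)\n",
    "toy_index.init_index(max_elements=4, ef_construction=100, M=16)\n",
    "toy_index.add_items(toy_vectors, list(range(4)))\n",
    "\n",
    "# Query with a single vector and ask for the 3 nearest neighbors\n",
    "toy_labels, toy_distances = toy_index.knn_query(toy_vectors[:1], k=3)\n",
    "\n",
    "print(toy_labels.shape, toy_distances.shape)  # one row per query, k columns\n",
    "print(toy_labels[0])  # neighbor ids for the first (and only) query"
   ]
  },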
  {
   "cell_type": "markdown",
   "id": "b924bd1d-930a-4624-8ad6-cc0eef8c5b07",
   "metadata": {},
   "source": [
    "### Step 6: Query the Retrieval System"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa6a8d7f-9cfe-4c58-8153-f148aaca89d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "queries = [\n",
    "    \"Are businesses meeting sustainability goals?\",\n",
    "    \"Can data science help meet sustainability goals?\"\n",
    "]\n",
    "\n",
    "for query in queries:\n",
    "    retrieval(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d8b6191-d218-4f4c-bc97-e9e3493113f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"Hvor kan jeg finde den seneste danske boligplan?\" # \"Where can I find the latest Danish property plan?\"\n",
    "retrieved_docs, translated_retrieved_docs = retrieval(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5605c59-0ef4-48f5-8dc9-4cbc7b95ec5d",
   "metadata": {},
   "source": [
    "### Step 7: Improve Results with Cohere Rerank\n",
    "\n",
    "The following query does not return the most relevant result at the top. This is where Rerank helps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e22c0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"Are companies ready for the next down market?\"\n",
    "retrieved_docs, translated_retrieved_docs = retrieval(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbe1bd3c",
   "metadata": {},
   "source": [
    "#### Subscribe to the model package in SageMaker\n",
    "\n",
    "Rerank is available in SageMaker.\n",
    "\n",
    "To subscribe to the model package:\n",
    "1. Open the model package listing page [cohere-rerank-multilingual](https://aws.amazon.com/marketplace/pp/prodview-pf7d2umihcseq).\n",
    "1. On the AWS Marketplace listing, click the **Continue to subscribe** button.\n",
    "1. On the **Subscribe to this software** page, review the terms and click **Accept Offer** if you and your organization agree with the EULA, pricing, and support terms.\n",
    "1. Click the **Continue to configuration** button and choose a **region**. A **Product Arn** is then displayed. This is the model package ARN that you need to specify when creating a deployable model using Boto3. Copy the ARN corresponding to your region and specify it in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "450c314a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each region to its model package ARN\n",
    "import boto3\n",
    "cohere_package = \"cohere-rerank-multilingual-v2--8b26a507962f3adb98ea9ac44cb70be1\" # replace this with your info\n",
    "\n",
    "model_package_map = {\n",
    "    \"us-east-1\": f\"arn:aws:sagemaker:us-east-1:865070037744:model-package/{cohere_package}\",\n",
    "    \"us-east-2\": f\"arn:aws:sagemaker:us-east-2:057799348421:model-package/{cohere_package}\",\n",
    "    \"us-west-1\": f\"arn:aws:sagemaker:us-west-1:382657785993:model-package/{cohere_package}\",\n",
    "    \"us-west-2\": f\"arn:aws:sagemaker:us-west-2:594846645681:model-package/{cohere_package}\",\n",
    "    \"ca-central-1\": f\"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{cohere_package}\",\n",
    "    \"eu-central-1\": f\"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{cohere_package}\",\n",
    "    \"eu-west-1\": f\"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{cohere_package}\",\n",
    "    \"eu-west-2\": f\"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{cohere_package}\",\n",
    "    \"eu-west-3\": f\"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{cohere_package}\",\n",
    "    \"eu-north-1\": f\"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{cohere_package}\",\n",
    "    \"ap-southeast-1\": f\"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{cohere_package}\",\n",
    "    \"ap-southeast-2\": f\"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{cohere_package}\",\n",
    "    \"ap-northeast-2\": f\"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{cohere_package}\",\n",
    "    \"ap-northeast-1\": f\"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{cohere_package}\",\n",
    "    \"ap-south-1\": f\"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{cohere_package}\",\n",
    "    \"sa-east-1\": f\"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{cohere_package}\",\n",
    "}\n",
    "\n",
    "region = boto3.Session().region_name\n",
    "if region not in model_package_map.keys():\n",
    "    raise Exception(f\"Current boto3 session region {region} is not supported.\")\n",
    "\n",
    "model_package_arn = model_package_map[region]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6dfc159",
   "metadata": {},
   "source": [
    "#### Create an endpoint and perform real-time inference\n",
    "\n",
    "To understand how real-time inference with Amazon SageMaker works, see the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad714c74",
   "metadata": {},
   "outputs": [],
   "source": [
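    "# Note: this rebinds `co` from the Bedrock client to a SageMaker client, so retrieval() from Step 5\n",
    "# will use this client too; re-create the Bedrock client from Step 4 before calling retrieval() again\n",
    "co = cohere_aws.Client(region_name=region)\n",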
    "co.create_endpoint(arn=model_package_arn, endpoint_name=\"cohere-rerank-multilingual\", instance_type=\"ml.g4dn.xlarge\", n_instances=1)\n",
    "\n",
    "# If the endpoint is already created, you just need to connect to it\n",
    "# co.connect_to_endpoint(endpoint_name=\"cohere-rerank-multilingual\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32facc5b",
   "metadata": {},
   "source": [
    "Once the endpoint has been created, you can perform real-time inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7ee9ff-03cd-4140-b972-768a67872c51",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = co.rerank(query=query, documents=retrieved_docs, top_n=1)\n",
    "\n",
    "for hit in results:\n",
    "    print(hit.document['text'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
